{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# sklearn学习笔记\n",
    "- 代码实战\n",
    "\n",
    "## 统计信息\n",
    "- 两大类：定性分析和定量分析\n",
    "- 最大、最小、均值、标准差、方差、极差、中位数\n",
    "---\n",
    "其他\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "array([3, 2, 3, 4, 4, 1, 2, 4, 4, 3])\n",
      "[1, 5, 8, 2, 4, 1]\n",
      "均值：3.5,最小值:1,最大值:8,方差:6.25,标准差:2.5,中位数:3.0,极差:7\n",
      "7.5 1.0\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import sklearn as sk\n",
    "data = np.random.normal(1,5,10) # 正态分布\n",
    "data = np.random.randint(1,5,10) # 均匀分布\n",
    "print(repr(data))\n",
    "data = [1,5,8,2,4,1]\n",
    "print(data)\n",
    "# 各种统计指标\n",
    "print('均值：%s,最小值:%s,最大值:%s,方差:%s,标准差:%s,中位数:%s,极差:%s'%(np.mean(data),np.min(data),np.max(data),np.var(data),np.std(data),np.median(data),np.ptp(data)))\n",
    "print(np.cov(data),np.corrcoef(data)) # 协方差，相关系数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载数据集\n",
    "- 使用默认的iris数据集\n",
    "- 注意：内置数据集以bunch类型存在，包含data、target等成员"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
       "       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
       "       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "iris = load_iris()  # load the iris dataset; the result is a sklearn.utils.Bunch\n",
    "#data = sk.datasets.load_iris()\n",
    "iris.data  # feature matrix (bare expression; only the cell's last expression is displayed)\n",
    "iris.target  # target vector (displayed as the cell output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据预处理\n",
    "- 消除量纲：针对列（特征），将各个特征缩放到可比的尺度，具体方法：Z标准化（均值为0、方差为1，不改变分布形状）、极差标准化（缩放到[0,1]区间）\n",
    "- 规范化：针对行（样本），将每个样本向量缩放为单位长度（即单位向量）\n",
    "- 连续特征离散化：二值化\n",
    "- 离散特征数值化：one-hot编码（哑编码）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.90068117  1.03205722 -1.3412724  -1.31297673]\n",
      " [-1.14301691 -0.1249576  -1.3412724  -1.31297673]\n",
      " [-1.38535265  0.33784833 -1.39813811 -1.31297673]\n",
      " [-1.50652052  0.10644536 -1.2844067  -1.31297673]\n",
      " [-1.02184904  1.26346019 -1.3412724  -1.31297673]]\n"
     ]
    }
   ],
   "source": [
    "# Preprocessing\n",
    "# NOTE(review): only `import sklearn as sk` is visible, so `sk.preprocessing` relies on the submodule\n",
    "# having been loaded by an earlier import — confirm on a fresh kernel.\n",
    "#dir(sk.preprocessing)\n",
    "# Z-score standardization fitted on, and applied to, the iris feature matrix\n",
    "new_data = sk.preprocessing.StandardScaler().fit_transform(iris.data)\n",
    "# Besides StandardScaler (Z-score) there are other options: MinMaxScaler (min-max scaling) and Normalizer (normalization)\n",
    "# Note: Z-score standardization works per column (removes the effect of units), while normalization rescales each row to a unit vector\n",
    "bin_data = sk.preprocessing.Binarizer(threshold=3).fit_transform(iris.data) # binarization with threshold 3\n",
    "one_data = sk.preprocessing.OneHotEncoder().fit_transform(iris.target.reshape((-1,1))) # one-hot encoding of the target labels\n",
    "print(new_data[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  (0, 1)\t1.0\n",
      "  (1, 2)\t1.0\n",
      "  (2, 0)\t1.0\n"
     ]
    }
   ],
   "source": [
    "data = np.array([[1,'a',10],[2,'b',-10],[4,'D',20],[2,'Z',30]])\n",
    "data_y = np.array([i[0] for i in data])  # first element of every row\n",
    "# One-hot encoding: the input is passed as an array shaped like a column vector and the categories are\n",
    "# ordered ascending; the printed result is a sparse matrix in (row, column) value form, not a diagonal matrix.\n",
    "# NOTE(review): the original note says the values must be positive integers — presumably a limit of the\n",
    "# scikit-learn version used here; confirm.\n",
    "one_data = sk.preprocessing.OneHotEncoder().fit_transform(data_y.reshape(-1,1))\n",
    "#print(one_data)\n",
    "test = np.array([2,7,1]).reshape(-1,1)\n",
    "# test is already a column vector, so the second reshape below changes nothing\n",
    "out = sk.preprocessing.OneHotEncoder().fit_transform(test.reshape(-1,1))\n",
    "#out = sk.preprocessing.OneHotEncoder().fit_transform(iris.target.reshape(-1,1))\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[   2.25, -230.  ,   12.5 ],\n",
       "       [   1.  , -200.  ,   10.  ],\n",
       "       [   2.  , -230.  ,  -10.  ],\n",
       "       [   4.  , -300.  ,   20.  ],\n",
       "       [   2.  , -157.  ,   30.  ]])"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The imputer lives in sklearn.impute as SimpleImputer since scikit-learn 0.20 and the old\n",
    "# sklearn.preprocessing.Imputer was later removed; try the new location first and fall back,\n",
    "# so this cell runs on both old and new versions.\n",
    "try:\n",
    "    from sklearn.impute import SimpleImputer as Imputer\n",
    "except ImportError:\n",
    "    from sklearn.preprocessing import Imputer\n",
    "\n",
    "data = np.array([[1,-200,10],[2,-230,-10],[4,-300,20],[2,-157,30]])\n",
    "# Prepend a row containing missing values (NaN); by default each NaN is filled with its column mean\n",
    "Imputer().fit_transform(np.vstack((np.array([np.nan,-230,np.nan]),data)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据变换\n",
    "- 多项式、对数、指数\n",
    "- 也可以自定义函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.3978952728\n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "# Polynomial feature expansion; the degree parameter defaults to 2\n",
    "PolynomialFeatures().fit_transform(iris.data)\n",
    "print(np.log1p(10))\n",
    "# A user-defined transformation can be used as well\n",
    "from numpy import log1p # log(1+x)\n",
    "from sklearn.preprocessing import FunctionTransformer\n",
    "# Data transformation with a custom function, here the logarithm log1p\n",
    "# The first argument is the function that is applied to the data\n",
    "FunctionTransformer(log1p).fit_transform(iris.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 特征选择\n",
    "从发散性和相关性两个维度来筛选特征，主要方法：\n",
    "- filter：过滤法\n",
    "- wrapper：包装法\n",
    "- embedded：嵌入法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1.4],\n",
       "       [ 1.4],\n",
       "       [ 1.3],\n",
       "       [ 1.5],\n",
       "       [ 1.4],\n",
       "       [ 1.7],\n",
       "       [ 1.4],\n",
       "       [ 1.5],\n",
       "       [ 1.4],\n",
       "       [ 1.5],\n",
       "       [ 1.5],\n",
       "       [ 1.6],\n",
       "       [ 1.4],\n",
       "       [ 1.1],\n",
       "       [ 1.2],\n",
       "       [ 1.5],\n",
       "       [ 1.3],\n",
       "       [ 1.4],\n",
       "       [ 1.7],\n",
       "       [ 1.5],\n",
       "       [ 1.7],\n",
       "       [ 1.5],\n",
       "       [ 1. ],\n",
       "       [ 1.7],\n",
       "       [ 1.9],\n",
       "       [ 1.6],\n",
       "       [ 1.6],\n",
       "       [ 1.5],\n",
       "       [ 1.4],\n",
       "       [ 1.6],\n",
       "       [ 1.6],\n",
       "       [ 1.5],\n",
       "       [ 1.5],\n",
       "       [ 1.4],\n",
       "       [ 1.5],\n",
       "       [ 1.2],\n",
       "       [ 1.3],\n",
       "       [ 1.5],\n",
       "       [ 1.3],\n",
       "       [ 1.5],\n",
       "       [ 1.3],\n",
       "       [ 1.3],\n",
       "       [ 1.3],\n",
       "       [ 1.6],\n",
       "       [ 1.9],\n",
       "       [ 1.4],\n",
       "       [ 1.6],\n",
       "       [ 1.4],\n",
       "       [ 1.5],\n",
       "       [ 1.4],\n",
       "       [ 4.7],\n",
       "       [ 4.5],\n",
       "       [ 4.9],\n",
       "       [ 4. ],\n",
       "       [ 4.6],\n",
       "       [ 4.5],\n",
       "       [ 4.7],\n",
       "       [ 3.3],\n",
       "       [ 4.6],\n",
       "       [ 3.9],\n",
       "       [ 3.5],\n",
       "       [ 4.2],\n",
       "       [ 4. ],\n",
       "       [ 4.7],\n",
       "       [ 3.6],\n",
       "       [ 4.4],\n",
       "       [ 4.5],\n",
       "       [ 4.1],\n",
       "       [ 4.5],\n",
       "       [ 3.9],\n",
       "       [ 4.8],\n",
       "       [ 4. ],\n",
       "       [ 4.9],\n",
       "       [ 4.7],\n",
       "       [ 4.3],\n",
       "       [ 4.4],\n",
       "       [ 4.8],\n",
       "       [ 5. ],\n",
       "       [ 4.5],\n",
       "       [ 3.5],\n",
       "       [ 3.8],\n",
       "       [ 3.7],\n",
       "       [ 3.9],\n",
       "       [ 5.1],\n",
       "       [ 4.5],\n",
       "       [ 4.5],\n",
       "       [ 4.7],\n",
       "       [ 4.4],\n",
       "       [ 4.1],\n",
       "       [ 4. ],\n",
       "       [ 4.4],\n",
       "       [ 4.6],\n",
       "       [ 4. ],\n",
       "       [ 3.3],\n",
       "       [ 4.2],\n",
       "       [ 4.2],\n",
       "       [ 4.2],\n",
       "       [ 4.3],\n",
       "       [ 3. ],\n",
       "       [ 4.1],\n",
       "       [ 6. ],\n",
       "       [ 5.1],\n",
       "       [ 5.9],\n",
       "       [ 5.6],\n",
       "       [ 5.8],\n",
       "       [ 6.6],\n",
       "       [ 4.5],\n",
       "       [ 6.3],\n",
       "       [ 5.8],\n",
       "       [ 6.1],\n",
       "       [ 5.1],\n",
       "       [ 5.3],\n",
       "       [ 5.5],\n",
       "       [ 5. ],\n",
       "       [ 5.1],\n",
       "       [ 5.3],\n",
       "       [ 5.5],\n",
       "       [ 6.7],\n",
       "       [ 6.9],\n",
       "       [ 5. ],\n",
       "       [ 5.7],\n",
       "       [ 4.9],\n",
       "       [ 6.7],\n",
       "       [ 4.9],\n",
       "       [ 5.7],\n",
       "       [ 6. ],\n",
       "       [ 4.8],\n",
       "       [ 4.9],\n",
       "       [ 5.6],\n",
       "       [ 5.8],\n",
       "       [ 6.1],\n",
       "       [ 6.4],\n",
       "       [ 5.6],\n",
       "       [ 5.1],\n",
       "       [ 5.6],\n",
       "       [ 6.1],\n",
       "       [ 5.6],\n",
       "       [ 5.5],\n",
       "       [ 4.8],\n",
       "       [ 5.4],\n",
       "       [ 5.6],\n",
       "       [ 5.1],\n",
       "       [ 5.1],\n",
       "       [ 5.9],\n",
       "       [ 5.7],\n",
       "       [ 5.2],\n",
       "       [ 5. ],\n",
       "       [ 5.2],\n",
       "       [ 5.4],\n",
       "       [ 5.1]])"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.feature_selection import VarianceThreshold\n",
    "# Variance-threshold feature selection: returns the data restricted to the selected features;\n",
    "# the threshold parameter is the variance cut-off\n",
    "VarianceThreshold(threshold=3).fit_transform(iris.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "constant() missing 1 required positional argument: 'value'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<timed exec>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n",
      "\u001b[1;31mTypeError\u001b[0m: constant() missing 1 required positional argument: 'value'"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Scratch cell: time a cell that builds a TensorFlow constant and then sleeps for three seconds.\n",
    "import tensorflow as tf\n",
    "import time\n",
    "tf.constant(0)  # tf.constant requires a value; calling it with no argument raised TypeError\n",
    "time.sleep(3)\n",
    "print('hello')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n",
      "Extracting /tmp/mnist_tutorial/data\\train-images-idx3-ubyte.gz\n",
      "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n",
      "Extracting /tmp/mnist_tutorial/data\\train-labels-idx1-ubyte.gz\n",
      "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n",
      "Extracting /tmp/mnist_tutorial/data\\t10k-images-idx3-ubyte.gz\n",
      "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n",
      "Extracting /tmp/mnist_tutorial/data\\t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "module 'urllib' has no attribute 'urlretrieve'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-1-11fac8e4af38>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      9\u001b[0m \u001b[0mmnist\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcontrib\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlearn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmnist\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread_data_sets\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_dir\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mLOGDIR\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;34m'data'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mone_hot\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     10\u001b[0m \u001b[1;31m### Get a sprite and labels file for the embedding projector ###\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 11\u001b[1;33m \u001b[0murllib\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0murlretrieve\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mGIST_URL\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;34m'labels_1024.tsv'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mLOGDIR\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;34m'labels_1024.tsv'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     12\u001b[0m \u001b[0murllib\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0murlretrieve\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mGIST_URL\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;34m'sprite_1024.png'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mLOGDIR\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;34m'sprite_1024.png'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     13\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mAttributeError\u001b[0m: module 'urllib' has no attribute 'urlretrieve'"
     ]
    }
   ],
   "source": [
    "#https://gist.github.com/dandelionmane/4f02ab8f1451e276fea1f165a20336f1#file-slides-pdf\n",
    "import os\n",
    "import urllib\n",
    "import urllib.request\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "# Root directory for the downloaded data, projector metadata, summaries and checkpoints\n",
    "LOGDIR = '/tmp/mnist_tutorial/'\n",
    "# Base URL of the gist's raw files; it ends with '/' so file names can be appended directly\n",
    "GIST_URL = 'https://gist.githubusercontent.com/dandelionmane/4f02ab8f1451e276fea1f165a20336f1/raw/dfb8ee95b010480d56a73f324aca480b3820c180/'\n",
    "\n",
    "### MNIST EMBEDDINGS ###\n",
    "mnist = tf.contrib.learn.datasets.mnist.read_data_sets(train_dir=LOGDIR + 'data', one_hot=True)\n",
    "### Get a sprite and labels file for the embedding projector ###\n",
    "# On Python 3 urlretrieve lives in urllib.request; urllib.urlretrieve exists on Python 2 only\n",
    "urllib.request.urlretrieve(GIST_URL + 'labels_1024.tsv', LOGDIR + 'labels_1024.tsv')\n",
    "urllib.request.urlretrieve(GIST_URL + 'sprite_1024.png', LOGDIR + 'sprite_1024.png')\n",
    "\n",
    "\n",
    "def conv_layer(input, size_in, size_out, name=\"conv\"):\n",
    "  \"\"\"Build a 5x5 convolution, ReLU and 2x2 max-pool inside the name scope `name`.\n",
    "\n",
    "  Args:\n",
    "    input: tensor passed to tf.nn.conv2d (assumed to be an NHWC image batch; confirm with callers).\n",
    "    size_in: number of input channels of the kernel.\n",
    "    size_out: number of output channels; also the length of the bias vector.\n",
    "    name: name scope under which the ops and summaries are created.\n",
    "\n",
    "  Returns:\n",
    "    The activations after max-pooling with a 2x2 window, stride 2 and SAME padding.\n",
    "  \"\"\"\n",
    "  with tf.name_scope(name):\n",
    "    kernel_shape = [5, 5, size_in, size_out]\n",
    "    weights = tf.Variable(tf.truncated_normal(kernel_shape, stddev=0.1), name=\"W\")\n",
    "    biases = tf.Variable(tf.constant(0.1, shape=[size_out]), name=\"B\")\n",
    "    convolved = tf.nn.conv2d(input, weights, strides=[1, 1, 1, 1], padding=\"SAME\")\n",
    "    activations = tf.nn.relu(convolved + biases)\n",
    "    # One histogram summary per tensor, in the same order as before\n",
    "    for tag, tensor in ((\"weights\", weights), (\"biases\", biases), (\"activations\", activations)):\n",
    "      tf.summary.histogram(tag, tensor)\n",
    "    pooled = tf.nn.max_pool(activations, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding=\"SAME\")\n",
    "  return pooled\n",
    "\n",
    "\n",
    "def fc_layer(input, size_in, size_out, name=\"fc\"):\n",
    "  \"\"\"Build a fully connected layer followed by ReLU inside the name scope `name`.\n",
    "\n",
    "  Args:\n",
    "    input: tensor that is matrix-multiplied with the weights (second dimension size_in).\n",
    "    size_in: number of input units.\n",
    "    size_out: number of output units; also the length of the bias vector.\n",
    "    name: name scope under which the ops and summaries are created.\n",
    "\n",
    "  Returns:\n",
    "    The ReLU activations of input * W + B.\n",
    "  \"\"\"\n",
    "  with tf.name_scope(name):\n",
    "    weights = tf.Variable(tf.truncated_normal([size_in, size_out], stddev=0.1), name=\"W\")\n",
    "    biases = tf.Variable(tf.constant(0.1, shape=[size_out]), name=\"B\")\n",
    "    pre_activation = tf.matmul(input, weights) + biases\n",
    "    activations = tf.nn.relu(pre_activation)\n",
    "    # One histogram summary per tensor, in the same order as before\n",
    "    for tag, tensor in ((\"weights\", weights), (\"biases\", biases), (\"activations\", activations)):\n",
    "      tf.summary.histogram(tag, tensor)\n",
    "  return activations\n",
    "\n",
    "\n",
    "def mnist_model(learning_rate, use_two_conv, use_two_fc, hparam):\n",
    "  \"\"\"Build and train a small MNIST CNN for one hyper-parameter setting.\n",
    "\n",
    "  Resets the default graph, builds the network, then runs 2001 training steps\n",
    "  on batches of 100 from the module-level `mnist` dataset. Summaries are\n",
    "  written every 5 steps; the embedding and a checkpoint every 500 steps.\n",
    "  The summary writer and the session are closed when training ends or fails.\n",
    "\n",
    "  Args:\n",
    "    learning_rate: learning rate passed to tf.train.AdamOptimizer.\n",
    "    use_two_conv: if true, stack two conv layers (1 -> 32 -> 64 channels);\n",
    "      otherwise use one conv layer (1 -> 64) followed by an extra max-pool.\n",
    "    use_two_fc: if true, insert a 1024-unit layer before the 10-unit output;\n",
    "      otherwise feed the flattened features straight into the output layer.\n",
    "    hparam: run name; summaries are written to LOGDIR + hparam.\n",
    "  \"\"\"\n",
    "  tf.reset_default_graph()\n",
    "  sess = tf.Session()\n",
    "\n",
    "  # Setup placeholders, and reshape the data\n",
    "  x = tf.placeholder(tf.float32, shape=[None, 784], name=\"x\")\n",
    "  x_image = tf.reshape(x, [-1, 28, 28, 1])\n",
    "  tf.summary.image('input', x_image, 3)\n",
    "  y = tf.placeholder(tf.float32, shape=[None, 10], name=\"labels\")\n",
    "\n",
    "  if use_two_conv:\n",
    "    conv1 = conv_layer(x_image, 1, 32, \"conv1\")\n",
    "    conv_out = conv_layer(conv1, 32, 64, \"conv2\")\n",
    "  else:\n",
    "    # conv_layer already pools once; the extra pool brings this branch down to\n",
    "    # the same 7x7 spatial size that the 7 * 7 * 64 flatten below expects.\n",
    "    conv1 = conv_layer(x_image, 1, 64, \"conv\")\n",
    "    conv_out = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding=\"SAME\")\n",
    "\n",
    "  flattened = tf.reshape(conv_out, [-1, 7 * 7 * 64])\n",
    "\n",
    "  # NOTE(review): fc_layer applies ReLU, so the logits below are non-negative — confirm this is intended.\n",
    "  if use_two_fc:\n",
    "    fc1 = fc_layer(flattened, 7 * 7 * 64, 1024, \"fc1\")\n",
    "    embedding_input = fc1\n",
    "    embedding_size = 1024\n",
    "    logits = fc_layer(fc1, 1024, 10, \"fc2\")\n",
    "  else:\n",
    "    embedding_input = flattened\n",
    "    embedding_size = 7*7*64\n",
    "    logits = fc_layer(flattened, 7*7*64, 10, \"fc\")\n",
    "\n",
    "  with tf.name_scope(\"xent\"):\n",
    "    xent = tf.reduce_mean(\n",
    "        tf.nn.softmax_cross_entropy_with_logits(\n",
    "            logits=logits, labels=y), name=\"xent\")\n",
    "    tf.summary.scalar(\"xent\", xent)\n",
    "\n",
    "  with tf.name_scope(\"train\"):\n",
    "    train_step = tf.train.AdamOptimizer(learning_rate).minimize(xent)\n",
    "\n",
    "  with tf.name_scope(\"accuracy\"):\n",
    "    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "    tf.summary.scalar(\"accuracy\", accuracy)\n",
    "\n",
    "  summ = tf.summary.merge_all()\n",
    "\n",
    "  # Variable that stores the embedding inputs of 1024 examples for the projector;\n",
    "  # it is filled by running `assignment` with the first 1024 test images.\n",
    "  embedding = tf.Variable(tf.zeros([1024, embedding_size]), name=\"test_embedding\")\n",
    "  assignment = embedding.assign(embedding_input)\n",
    "  saver = tf.train.Saver()\n",
    "\n",
    "  sess.run(tf.global_variables_initializer())\n",
    "  writer = tf.summary.FileWriter(LOGDIR + hparam)\n",
    "  try:\n",
    "    writer.add_graph(sess.graph)\n",
    "\n",
    "    config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()\n",
    "    embedding_config = config.embeddings.add()\n",
    "    embedding_config.tensor_name = embedding.name\n",
    "    embedding_config.sprite.image_path = LOGDIR + 'sprite_1024.png'\n",
    "    embedding_config.metadata_path = LOGDIR + 'labels_1024.tsv'\n",
    "    # Specify the width and height of a single thumbnail.\n",
    "    embedding_config.sprite.single_image_dim.extend([28, 28])\n",
    "    tf.contrib.tensorboard.plugins.projector.visualize_embeddings(writer, config)\n",
    "\n",
    "    for i in range(2001):\n",
    "      batch = mnist.train.next_batch(100)\n",
    "      if i % 5 == 0:\n",
    "        # Only the merged summary is written; the accuracy value itself is not used here\n",
    "        _, s = sess.run([accuracy, summ], feed_dict={x: batch[0], y: batch[1]})\n",
    "        writer.add_summary(s, i)\n",
    "      if i % 500 == 0:\n",
    "        sess.run(assignment, feed_dict={x: mnist.test.images[:1024], y: mnist.test.labels[:1024]})\n",
    "        saver.save(sess, os.path.join(LOGDIR, \"model.ckpt\"), i)\n",
    "      sess.run(train_step, feed_dict={x: batch[0], y: batch[1]})\n",
    "  finally:\n",
    "    # Flush pending summaries and release the session even if training is interrupted,\n",
    "    # so repeated runs do not accumulate open writers and sessions.\n",
    "    writer.close()\n",
    "    sess.close()\n",
    "\n",
    "def make_hparam_string(learning_rate, use_two_fc, use_two_conv):\n",
    "  conv_param = \"conv=2\" if use_two_conv else \"conv=1\"\n",
    "  fc_param = \"fc=2\" if use_two_fc else \"fc=1\"\n",
    "  return \"lr_%.0E,%s,%s\" % (learning_rate, conv_param, fc_param)\n",
    "\n",
    "def main():\n",
    "  \"\"\"Train one model per hyper-parameter combination, each logged under its own run name.\"\"\"\n",
    "  # You can try adding some more learning rates\n",
    "  for learning_rate in [1E-4]:\n",
    "\n",
    "    # Include \"False\" as a value to try different model architectures\n",
    "    for use_two_fc in [True]:\n",
    "      for use_two_conv in [True]:\n",
    "        # Construct a hyperparameter string for each one (example: \"lr_1E-04,conv=2,fc=2\")\n",
    "        hparam = make_hparam_string(learning_rate, use_two_fc, use_two_conv)\n",
    "        print('Starting run for %s' % hparam)\n",
    "\n",
    "        # Pass the flags by keyword: mnist_model takes use_two_conv before use_two_fc,\n",
    "        # so passing them positionally in (fc, conv) order would swap the two settings\n",
    "        # and train a different architecture than the run name says.\n",
    "        mnist_model(learning_rate, use_two_conv=use_two_conv, use_two_fc=use_two_fc, hparam=hparam)\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "  main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
